import random
import os
import math
import torch
import torch.nn as nn
from utils import *
from layers import *
from transformers import BertModel

class BubbleEmbedBase(nn.Module):
    """Encode text with a pre-trained BERT model and project it to a "bubble".

    A bubble is a center vector (width ``args.embed_size``) plus a strictly
    positive scalar radius. Both are computed from the encoder's first output
    at sequence position 0. For BERT tokenizers this is presumably the [CLS]
    token; confirm against the tokenizer used by callers.

    Attributes read from ``args``: ``dropout``, ``hidden``, ``embed_size`` and
    ``dataset``.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.pre_train_model = self.__load_pre_trained__()
        self.dropout = nn.Dropout(args.dropout)

        # Size the projection heads from the encoder's own config rather than a
        # hard-coded 768, so they stay consistent with whatever model
        # __load_pre_trained__ returns. bert-base-uncased reports 768, so the
        # default behaviour is unchanged. Fall back to 768 if the model exposes
        # no config.
        encoder_config = getattr(self.pre_train_model, "config", None)
        encoder_dim = getattr(encoder_config, "hidden_size", 768)

        # MLP_VEC comes from the project-local `layers` module. It presumably
        # maps input_dim -> output_dim; with num_hidden_layers=0, `hidden`
        # may be unused. Confirm in layers.
        self.projection_center = MLP_VEC(
            input_dim=encoder_dim,
            hidden=self.args.hidden,
            output_dim=self.args.embed_size,
            num_hidden_layers=0,
        )
        self.projection_radius = MLP_VEC(
            input_dim=encoder_dim,
            hidden=self.args.hidden,
            output_dim=1,
            num_hidden_layers=0,
        )

    def __load_pre_trained__(self):
        """Load and return the ``bert-base-uncased`` encoder (downloads on first use)."""
        model = BertModel.from_pretrained("bert-base-uncased")
        print("Model Loaded!")
        return model

    def projection_bubble(self, encode_inputs):
        """Encode ``encode_inputs`` and return ``(center, radius)``.

        Args:
            encode_inputs: mapping unpacked as keyword arguments into the
                encoder. It is presumably tokenizer output such as
                ``input_ids`` / ``attention_mask``; verify against callers.

        Returns:
            center: output of ``projection_center`` on the dropped-out
                position-0 features.
            radius: ``exp`` of ``projection_radius``'s output, floored at 1e-8.
                For the ``"movielens"`` dataset it is also capped at
                ``sqrt(args.embed_size) / 2``.
        """
        encoder_out = self.pre_train_model(**encode_inputs)
        # Element 0 of the encoder output is indexed as a rank-3 tensor. Keep
        # sequence position 0 for every batch row.
        first_token = self.dropout(encoder_out[0][:, 0, :])

        center = self.projection_center(first_token)
        # exp keeps the radius positive. The floor guards against exp
        # underflowing to exactly zero for very negative logits.
        radius = torch.exp(self.projection_radius(first_token)).clamp_min(1e-8)
        if self.args.dataset == "movielens":
            # Dataset-specific upper bound on the bubble size.
            radius = torch.clamp(radius, max=math.sqrt(self.args.embed_size) / 2)
        return center, radius

    def forward(self, encode_inputs, flag="train"):
        """Return ``(center, radius)`` for ``encode_inputs``.

        ``flag`` is accepted for call-site compatibility but is not used here.
        """
        center, delta = self.projection_bubble(encode_inputs)
        return center, delta